import torch
import torch.nn as nn
from torch import Tensor
import numpy as np
from torch_geometric.nn import GCNConv


class GraphProfiler(nn.Module):
    """Binary event classifier that profiles source nodes from stored history.

    Besides its learnable layers the module keeps two pieces of running state
    that are filled by :meth:`forward`:

    * ``storage`` -- the most recently seen events (a batch-like object), and
    * ``edges`` -- a ``(2, E)`` index tensor that links every stored source to
      each later source that contacted the same destination.

    A batch must expose the tensors ``src``, ``dst``, ``t``, ``msg`` and ``y``,
    support ``len()``, boolean-mask indexing and ``.to(device)``.
    NOTE(review): this looks like ``torch_geometric.data.TemporalData`` --
    confirm against the caller.
    """

    def __init__(
            self,
            num_src: int,
            num_dst: int,
            msg_dim: int,
            model_dim: int = 100,
            max_storage_len: int = 1000,
            max_num_edges: int = 1000,
            history_aggregation: str = "max",
            time_encoding_version: str = "learn",
            readout_version: str = "both",
            gcn_add_self_loops: int = 0,
            gcn_improved: int = 1,
            gcn_aggr:  str = "max",
            device=None,
            **kwargs
    ):
        """Build the encoders, the graph convolution and the readout head.

        Args:
            num_src: Number of source nodes (stored, not otherwise used here).
            num_dst: Number of destination nodes; the destination embedding
                has ``num_dst + 1`` rows with index 0 as padding.
            msg_dim: Feature size of an event message.
            model_dim: Width of every internal embedding.
            max_storage_len: Event count above which ``storage`` is trimmed
                to its most recent events.
            max_num_edges: Maximum number of columns kept in ``edges``.
            history_aggregation: ``"mean"`` or ``"max"``; how a node's stored
                events are pooled in :meth:`profile_encoder`.
            time_encoding_version: Forwarded to ``TimeEncoder(version=...)``.
            readout_version: Forwarded to ``Readout(version=...)``.
            gcn_add_self_loops: Cast to ``bool`` and forwarded to ``GCNConv``.
            gcn_improved: Cast to ``bool`` and forwarded to ``GCNConv``.
            gcn_aggr: Aggregation scheme forwarded to ``GCNConv``.
            device: Device all sub-modules are moved to and on which state
                tensors are created.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.num_src = num_src
        self.num_dst = num_dst
        self.msg_dim = msg_dim
        self.model_dim = model_dim
        self.criterion = torch.nn.BCEWithLogitsLoss()
        self.device = device

        self.max_storage_len = max_storage_len
        self.max_num_edges = max_num_edges

        self.history_aggregation = history_aggregation

        # Running state; populated by update_storage / update_edges.
        self.storage = None
        self.edges = None

        # Destination embeddings are frozen: they are never trained.
        self.dst_encoder = nn.Embedding(
            num_embeddings=self.num_dst + 1,
            embedding_dim=self.model_dim,
            padding_idx=0,
        ).requires_grad_(False).to(device)
        self.msg_encoder = nn.Linear(
            in_features=self.msg_dim, out_features=self.model_dim).to(device)
        self.time_encoder = TimeEncoder(
            out_features=self.model_dim,
            version=time_encoding_version,
            device=device,
        ).to(device)
        # Input is [dst | msg | time | label] => 3 * model_dim + 1 features.
        self.profile_lin = nn.Linear(
            in_features=3 * self.model_dim + 1,
            out_features=self.model_dim,
        ).to(device)
        self.conv = GCNConv(
            self.model_dim,
            self.model_dim,
            add_self_loops=bool(gcn_add_self_loops),
            improved=bool(gcn_improved),
            aggr=gcn_aggr,
        ).to(device)
        self.readout = Readout(
            in_features=self.model_dim,
            out_features=1,
            version=readout_version,
        ).to(device)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the encoders and the profile projection.

        ``conv`` and ``readout`` are not touched here.
        """
        # 'src_encoder' is not created in __init__; the lookup is kept so a
        # variant that defines it is still reset.
        for name in ("src_encoder", "dst_encoder", "msg_encoder",
                     "time_encoder", "profile_lin"):
            module = getattr(self, name, None)
            if module is not None:
                module.reset_parameters()

    def reset_state(self):
        """Drop the stored events and edges and reset the node embeddings.

        NOTE(review): this also re-randomises the frozen ``dst_encoder``;
        confirm that is intended when the state is reset between
        training and evaluation.
        """
        for name in ("src_encoder", "dst_encoder"):
            module = getattr(self, name, None)
            if module is not None:
                module.reset_parameters()
        self.storage = None
        self.edges = None

    def detach(self):
        """No-op placeholder."""
        pass

    def forward(self, batch):
        """Score one batch of events and then append it to the history.

        Args:
            batch: Batch of events (see the class docstring for the
                required attributes).

        Returns:
            Tuple ``(loss, y_pred, y_true)``: the ``BCEWithLogitsLoss`` of
            the logits against ``batch.y``, the detached sigmoid
            probabilities on the CPU, and the float labels on the CPU, the
            latter two with a trailing dimension of size 1.
        """
        batch = batch.to(self.device)

        # Neighborhood updates: profile every node of the stored graph plus
        # the batch sources, convolve, then pick the rows of the sources.
        n_id, edge_index, mask = self.neighbor_loader(batch)
        x = self.profile_encoder(n_id)
        x = self.conv(x, edge_index)
        src_emb = x[mask]

        # Recent event
        dst_emb = self.dst_encoder(batch.dst)
        msg_emb = self.msg_encoder(batch.msg)
        time_emb = self.time_encoder(batch.t)

        # Readout
        logits = self.readout(src_emb, dst_emb, msg_emb, time_emb)
        y = batch.y.float().unsqueeze(-1)
        loss = self.criterion(logits, y)

        # Updates: the batch only becomes history after it has been scored.
        self.update_storage(batch)

        y_pred = logits.detach().cpu().sigmoid()
        y_true = batch.y.detach().cpu().float().unsqueeze(-1)

        return loss, y_pred, y_true

    def neighbor_loader(self, batch):
        """Relabel the stored edges and the batch sources to local indices.

        Returns:
            Tuple ``(n_id, edge_index, mask)`` where ``n_id`` holds the node
            ids, ``edge_index`` is ``edges`` expressed as positions in
            ``n_id`` and ``mask`` gives the position of every ``batch.src``
            entry in ``n_id``.  Before any event has been stored, ``n_id`` is
            ``batch.src`` itself and ``edge_index`` is empty.
        """
        if self.storage is None:
            edge_index = torch.empty((2, 0), dtype=torch.long, device=self.device)
            mask = torch.arange(len(batch), device=self.device)
            return batch.src, edge_index, mask

        n_id, idx = torch.cat(
            [self.edges[0, :], self.edges[1, :], batch.src]
        ).unique(return_inverse=True, sorted=True)
        num_edge_entries = torch.numel(self.edges)
        edge_index = idx[:num_edge_entries].view_as(self.edges)
        mask = idx[num_edge_entries:]

        return n_id, edge_index, mask

    def _aggregate_history(self, z, idx):
        """Pool the rows of ``z`` that share a group index, then apply tanh.

        Args:
            z: ``(N, D)`` float tensor of per-event features.
            idx: ``(N,)`` long tensor assigning each row to a group; the
                group ids must be ``0 .. G-1``.

        Returns:
            ``(G, D)`` float tensor with one pooled row per group.

        Raises:
            NotImplementedError: If ``history_aggregation`` is neither
                ``"mean"`` nor ``"max"``.
        """
        idx = idx.view(idx.size(0), 1).expand(-1, z.size(1))
        unique_idx, unique_count = idx.unique(dim=0, return_counts=True)
        pooled = torch.zeros_like(unique_idx, dtype=torch.float)
        if self.history_aggregation == "mean":
            pooled.scatter_add_(0, idx, z)
        elif self.history_aggregation == "max":
            # The accumulator starts at zero, so pooled values are never
            # negative.
            pooled.scatter_reduce_(0, idx, z, reduce="amax")
        else:
            raise NotImplementedError
        # NOTE(review): the "max" variant is also divided by the group size;
        # kept as in the original numerics -- confirm this is intended.
        pooled = pooled / unique_count.float().unsqueeze(1)
        return torch.tanh(pooled)

    def profile_encoder(self, n_ids):
        """Build a ``model_dim`` profile for each node id from stored events.

        Every stored event is encoded as ``[dst | msg | time | label]``
        (without gradient through the three encoders), events are pooled per
        source node, and the result is projected by ``profile_lin``.

        Args:
            n_ids: 1-D tensor of node ids to profile.

        Returns:
            ``(len(n_ids), model_dim)`` float tensor; all zeros while no
            event has been stored.
        """
        if self.storage is None:
            return torch.zeros(n_ids.size(0), self.model_dim, dtype=torch.float, device=self.device)

        history = self.storage
        # Explicit float cast so the label column matches the float
        # accumulator used by the scatter ops.
        record_emb = history.y.float().unsqueeze(-1)
        with torch.no_grad():
            dst_emb = self.dst_encoder(history.dst)
            msg_emb = self.msg_encoder(history.msg)
            time_emb = self.time_encoder(history.t)
        history_emb = torch.cat((dst_emb, msg_emb, time_emb, record_emb), dim=1)

        # Append one all-zero row per requested node so that nodes without
        # any stored event still get a group (and hence a profile).
        _, src_idx = torch.cat((history.src, n_ids)).unique(return_inverse=True)
        dummy_emb = torch.zeros(n_ids.size(0), history_emb.size(1), dtype=torch.float, device=self.device)
        history_emb = torch.cat((history_emb, dummy_emb), dim=0)

        z = self._aggregate_history(history_emb, src_idx)
        z = self.profile_lin(z)
        # The trailing entries of src_idx are the groups of n_ids, in order.
        return z[src_idx[history.src.size(0):]]

    def update_storage(self, batch):
        """Append ``batch`` to the history and enforce both size limits.

        The caller's ``batch`` object is never modified.
        """
        if self.storage is None:
            # Store a copy (boolean-mask indexing returns a new object) so
            # later concatenations do not grow the caller's batch.
            keep_all = torch.ones_like(batch.src, dtype=torch.bool)
            self.storage = batch[keep_all]
        else:
            self.storage.src = torch.cat((self.storage.src, batch.src), 0)
            self.storage.dst = torch.cat((self.storage.dst, batch.dst), 0)
            self.storage.t = torch.cat((self.storage.t, batch.t), 0)
            self.storage.msg = torch.cat((self.storage.msg, batch.msg), 0)
            self.storage.y = torch.cat((self.storage.y, batch.y), 0)

        self.update_edges(batch)

        if len(self.storage) > self.max_storage_len:
            # Keep every event at least as recent as the k-th latest one
            # (ties on that timestamp may keep slightly more than k).
            t0 = torch.topk(self.storage.t, self.max_storage_len)[0].min()
            self.storage = self.storage[self.storage.t >= t0]

        if self.edges.size(1) > self.max_num_edges:
            # Edges are (2, E) with the newest columns first, so truncate
            # along dimension 1.
            self.edges = self.edges[:, :self.max_num_edges]

    def update_edges(self, batch):
        """Prepend edges from stored sources to the sources of ``batch``.

        For every event in ``batch`` an edge ``(stored_src, event_src)`` is
        created for each stored event with the same destination.  Edges of
        later batch events come first; within one event they follow storage
        order.
        """
        if self.storage is None:
            return

        # Reverse the batch so the row-major order of nonzero() yields the
        # edges of the last event first.
        recent_dst = batch.dst.flip(0)
        recent_src = batch.src.flip(0)
        same_dst = recent_dst.view(-1, 1) == self.storage.dst.view(1, -1)
        batch_pos, store_pos = same_dst.nonzero(as_tuple=True)
        new_edges = torch.stack(
            (self.storage.src[store_pos], recent_src[batch_pos])
        ).long()

        if self.edges is None:
            self.edges = new_edges
        else:
            self.edges = torch.cat((new_edges, self.edges), 1)


class TimeEncoder(nn.Module):
    """Encode scalar timestamps as ``cos(W * t + b)`` features.

    With ``version="fix"`` the weights are the fixed, non-trainable
    frequencies ``1 / 10 ** linspace(0, 9, out_features)`` with zero bias and
    the forward pass runs without gradient tracking.  Any other version uses
    an ordinary trainable ``nn.Linear(1, out_features)``.
    """

    def __init__(self, out_features: int, device, version="fix"):
        """
        Args:
            out_features: Size of the produced encoding.
            device: Device on which the fixed weights are created.
            version: ``"fix"`` for fixed frequencies, anything else for a
                learnable projection.
        """
        super().__init__()
        self.out_features = out_features
        self.lin = nn.Linear(1, out_features)
        self.device = device
        self.version = version
        if self.version == "fix":
            # Install the fixed frequencies right away so the encoder is
            # correct even if reset_parameters() is never called.  This is
            # deterministic and does not consume random numbers.
            self.reset_parameters()
    # TODO: device management here is not perfect but works for the moment

    def reset_parameters(self):
        """(Re-)initialise the projection according to ``version``."""
        if self.version == "fix":
            freqs = 1 / 10 ** np.linspace(0, 9, self.out_features, dtype=np.float32)
            weight = torch.as_tensor(freqs, device=self.device).reshape(self.out_features, -1)
            bias = torch.zeros(self.out_features, device=self.device)
            self.lin.weight = nn.Parameter(weight, requires_grad=False)
            self.lin.bias = nn.Parameter(bias, requires_grad=False)
        else:
            self.lin.reset_parameters()

    def forward(self, t: Tensor) -> Tensor:
        """Return the ``(t.numel(), out_features)`` encoding of ``t``."""
        column = t.float().view(-1, 1)
        # Compare against self.version: `self.type` is the nn.Module.type
        # method and never equals "fix".
        if self.version == "fix":
            with torch.no_grad():
                return self.lin(column).cos()
        return self.lin(column).cos()


class Readout(nn.Module):
    """Two-layer head that fuses message, source and destination embeddings.

    The message embedding is always projected; depending on ``version``
    (``"src"``, ``"dst"`` or ``"both"``) projections of the source and/or
    destination embeddings are added before a ReLU and a final linear layer.
    """

    def __init__(self, in_features: int, out_features: int = 1, version=None):
        """
        Args:
            in_features: Size of every input embedding and of the hidden layer.
            out_features: Size of the output.
            version: ``"src"``, ``"dst"`` or ``"both"``.

        Raises:
            NotImplementedError: For any other ``version``.
        """
        super().__init__()
        self.version = version
        self.lin_msg = nn.Linear(in_features, in_features)

        with_src = self.version == "src" or self.version == "both"
        with_dst = self.version == "dst" or self.version == "both"
        if not (with_src or with_dst):
            raise NotImplementedError
        # Source projection is created before the destination one.
        if with_src:
            self.lin_src = nn.Linear(in_features, in_features)
        if with_dst:
            self.lin_dst = nn.Linear(in_features, in_features)

        self.lin_final = nn.Linear(in_features, out_features)

    def forward(self, z_src, z_dst, z_msg, z_t):
        """Return ``lin_final(relu(sum of the enabled projections))``.

        ``z_t`` is accepted but not used.
        """
        hidden = self.lin_msg(z_msg)
        src_proj = getattr(self, "lin_src", None)
        dst_proj = getattr(self, "lin_dst", None)
        if src_proj is not None:
            hidden = hidden + src_proj(z_src)
        if dst_proj is not None:
            hidden = hidden + dst_proj(z_dst)
        return self.lin_final(hidden.relu())





